#!/bin/bash

# Keep command tracing off in this outer script; the containerized commands
# launched further down turn on their own tracing via `bash -cx`.
set +x

# Required inputs. Each check aborts the script with a message when the
# variable is unset or empty.
: "${BENCHMARK:?BENCHMARK not set}"
: "${BZ_PER_GPU:?BZ_PER_GPU not set}"
: "${NODE_NUM:?NODE_NUM not set}"
: "${GPU_NUM:?GPU_NUM not set}"
: "${MIXED_PRECISION:?MIXED_PRECISION not set}"

# Derive the optional training flag from MIXED_PRECISION. Reset it first so a
# value inherited from the caller's environment cannot leak into the training
# command when mixed precision was not requested.
MIXED_PRECISION_FLAG=""
if [[ "${MIXED_PRECISION}" == "FP16" ]]; then
    MIXED_PRECISION_FLAG="--use_mixed_precision"
fi

# Dataset locations derived from NEW_CRITEO_MOUNT, exported so that processes
# started from this script can see them.
WDL_DATA_PATH="${NEW_CRITEO_MOUNT}/wdl_data"
DCN_DATA_PATH="${NEW_CRITEO_MOUNT}/dcn_data"
export WDL_DATA_PATH DCN_DATA_PATH

# Before benchmarking, start NODE_NUM tasks that each report their hostname,
# flush pending writes with `sync`, and set vm.drop_caches=3 via sudo.
drop_caches_cmd="echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3"
srun --ntasks="${NODE_NUM}" bash -c "${drop_caches_cmd}"

# Choose what to launch:
#   * BENCHMARK == "dlrm" with more than one node -> dgx_a100_14x8x640.py
#   * BENCHMARK == "dlrm" with exactly one node   -> dgx_a100.py
#   * anything else                               -> generic benchmark_train.py
dlrm_sample=""
if [[ ${BENCHMARK} == "dlrm" ]]; then
    if [[ ${NODE_NUM} -gt 1 ]]; then
        dlrm_sample="dgx_a100_14x8x640.py"
    elif [[ ${NODE_NUM} -eq 1 ]]; then
        dlrm_sample="dgx_a100.py"
    fi
fi

# srun arguments shared by every launch below.
launcher=(
    srun
    --ntasks="${NODE_NUM}"
    --container-image="${CONT}"
    --container-mounts="${MOUNTS}"
)

if [[ -n ${dlrm_sample} ]]; then
    # DLRM samples are started from the dataset directory, always with
    # `--network sharp` and with numactl's --membind=1,3,5,7 setting.
    "${launcher[@]}" --network sharp bash -cx " \
    cd /raid/datasets/criteo/mlperf/40m.limit_preshuffled/ && \
    numactl --membind=1,3,5,7 python3 /workdir/samples/dlrm/${dlrm_sample}"
else
    # USE_SHARP is left unquoted on purpose: it may be empty or expand to
    # several srun arguments, so it has to undergo word splitting.
    # shellcheck disable=SC2086
    "${launcher[@]}" $USE_SHARP bash -cx " \
    cd /workdir/ci/benchmark/train_benchmark && \
    python3 ./benchmark_train.py \
        --benchmark ${BENCHMARK} \
        --batchsize_per_gpu ${BZ_PER_GPU} \
        --node_num ${NODE_NUM} \
        --gpu_num ${GPU_NUM} \
        ${MIXED_PRECISION_FLAG}"
fi
